#!/usr/bin/env python3

#==============================================================================
# Imports and generic utilities
#==============================================================================
import pdb
import sys
import os
import re
import json
import importlib
from pprint import pprint as pp
import PIL.Image
import numpy as np
import json

# utility funcs, classes, etc go here.

def asserting(cond):
    """Assert that `cond` is truthy, opening the debugger first when it is not.

    A falsy `cond` starts a `pdb` session at the point of failure; once that
    session is continued, the ordinary `assert` raises `AssertionError`.
    """
    if cond:
        return
    pdb.set_trace()
    assert cond

def has_stdin():
    """Report whether standard input is something other than a terminal.

    Returns True when stdin is not attached to a TTY (for example a pipe or
    a redirected file) and False when it is an interactive terminal.
    """
    attached_to_terminal = sys.stdin.isatty()
    return not attached_to_terminal

def reg(pat, flags=0):
    """Compile `pat` as a verbose-mode regular expression.

    `re.VERBOSE` is always enabled; `flags` supplies any additional `re`
    flags to combine with it.
    """
    combined_flags = flags | re.VERBOSE
    return re.compile(pat, combined_flags)


# https://stackoverflow.com/a/56090741
def import_path(path):
    """Load the Python source file at `path` as a module and return it.

    The module name is the file's basename with '-' replaced by '_' (any
    file extension stays part of the name). A relative `path` is resolved
    against the directory containing this file; an absolute `path` is used
    as given. The loaded module is stored in `sys.modules` under that name,
    so a later call with the same basename returns the cached module
    without executing the file again.
    """
    # `import importlib` alone does not guarantee that these submodules are
    # loaded, so import them explicitly before touching their attributes.
    import importlib.machinery
    import importlib.util

    module_name = os.path.basename(path).replace('-', '_')
    if module_name in sys.modules:
        return sys.modules[module_name]
    source_path = os.path.join(os.path.dirname(__file__), path)
    loader = importlib.machinery.SourceFileLoader(module_name, source_path)
    spec = importlib.util.spec_from_loader(module_name, loader)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Register only after the module body ran successfully.
    sys.modules[module_name] = module
    return module

def function(path, name=None):
    """Resolve `path` to a callable.

    * A callable is returned unchanged.
    * A string that starts with 'lambda' and contains ':' is evaluated as a
      lambda expression; the result's `__name__` is set to the source text
      and its `__qualname__` to `function(<repr of the source text>)`.
    * Any other string is loaded with `import_path`, and the attribute
      called `name` (by default, the module's own name) is returned.
    """
    if callable(path):
        return path
    looks_like_lambda = path.startswith('lambda') and ':' in path
    if not looks_like_lambda:
        module = import_path(path)
        attribute = module.__name__ if name is None else name
        return getattr(module, attribute)
    # NOTE(review): eval() executes arbitrary code -- only pass trusted text.
    code = compile(path, '<lambda>', mode='eval')
    fn = eval(code)
    fn.__name__ = path
    fn.__qualname__ = f'function({path!r})'
    return fn


#==============================================================================
# Functionality
#==============================================================================

import webcolors
import cv2 # pip3 install opencv-python
import imgcat

def hex2rgb(color):
  """Convert a color name or hex string to a tuple of integer channels.

  Tuples and lists are returned untouched. A string without a leading '#'
  is first looked up as a color name via `webcolors`; if that lookup raises
  ValueError the string is parsed as hex digits as-is. Strings of up to four
  hex digits are read as one digit per channel (scaled from 0-15 to 0-255);
  longer strings are read as two digits per channel.
  """
  if isinstance(color, (tuple, list)):
    return color
  if not color.startswith('#'):
    try:
      color = webcolors.name_to_hex(color)
    except ValueError:
      pass  # not a known name: fall through and parse it as bare hex digits
  digits = color.lstrip('#')
  if len(digits) > 4:
    pairs = zip(digits[::2], digits[1::2])
    return tuple(int(hi + lo, 16) for hi, lo in pairs)
  return tuple(int((int(digit, 16) / 15.0) * 255.0) for digit in digits)

def rgb2hex(color):
  """Convert a color to a '#rrggbb' hex string.

  Strings that already start with '#' are returned unchanged; other strings
  are resolved as color names through `webcolors`. For a channel sequence,
  `webcolors.rgb_to_name` is tried first (a ValueError there is ignored) and
  the first three channels are then formatted as two lowercase hex digits
  each.
  """
  if isinstance(color, str):
    return color if color.startswith('#') else webcolors.name_to_hex(color)
  try:
    named = webcolors.rgb_to_name(color)
    color = hex2rgb(named)
  except ValueError:
    pass  # no exact named match: format the channels that were given
  return "#{0:02x}{1:02x}{2:02x}".format(*color)

def imgat(image, x, y):
  """Return the pixel at column `x`, row `y`, or zeros when out of bounds.

  The first two dimensions of `image` are taken as (height, width). For a
  coordinate outside the image the result is a zero-filled value shaped
  like `image[0, 0]`.
  """
  rows, cols = image.shape[:2]
  inside = 0 <= x < cols and 0 <= y < rows
  if inside:
    return image[y, x]
  return np.zeros_like(image[0, 0])

def transparent(image, x, y):
  """Return whether the pixel at (`x`, `y`) has an alpha value of 0.

  The pixel is fetched with `imgat`, so out-of-bounds coordinates read as
  an all-zero pixel. The pixel must have exactly four channels.
  """
  rgba = imgat(image, x, y)
  assert np.shape(rgba) == (4,)
  alpha = rgba[-1]
  return alpha == 0

def expand_rectangle_predicate(image, coords, predicate):
  """Grow an (x, y, w, h) rectangle while `predicate` holds at its probes.

  The rectangle is extended in four successive passes, each probing
  `predicate(image, px, py)`:

  1. left edge moves left while x > 0 and the predicate holds at (x, y);
  2. top edge moves up while y > 0 and the predicate holds at (x, y);
  3. width grows while w is below the image width and the predicate holds
     at (x + w, y);
  4. height grows while h is below the image height and the predicate
     holds at (x, y + h).

  Returns the resulting (x, y, w, h) tuple.
  """
  img_h, img_w = np.shape(image)[:2]
  left_x, top_y, wid, hgt = coords
  # Pass 1: push the left edge outwards, widening to keep the right edge.
  while left_x > 0 and predicate(image, left_x, top_y):
    left_x, wid = left_x - 1, wid + 1
  # Pass 2: push the top edge outwards, keeping the bottom edge.
  while top_y > 0 and predicate(image, left_x, top_y):
    top_y, hgt = top_y - 1, hgt + 1
  # Pass 3: extend the right side.
  while wid < img_w and predicate(image, left_x + wid, top_y):
    wid += 1
  # Pass 4: extend the bottom side.
  while hgt < img_h and predicate(image, left_x, top_y + hgt):
    hgt += 1
  return left_x, top_y, wid, hgt

def to_two_tuple(x):
  """Normalise `x` to a pair.

  A scalar becomes (x, x); a one-element list/tuple becomes (v, v); a
  two-element list/tuple is returned as a tuple. Any other length fails the
  assertion.
  """
  if not isinstance(x, (list, tuple)):
    return (x, x)
  assert 0 < len(x) <= 2
  if len(x) == 1:
    only = x[0]
    return (only, only)
  return tuple(x)
      

def shift_shape(shape, by_x, by_y):
  """Return the in-range indices of two axes after shifting them.

  The first two entries of `shape` are read as (W, H). `arange(W)` is
  offset by `by_x` and `arange(H)` by `by_y`; indices that fall outside
  [0, W) or [0, H) respectively are dropped. Returns the two index arrays.
  """
  first_len, second_len = shape[:2]
  first_idx = np.arange(first_len) + by_x
  second_idx = np.arange(second_len) + by_y
  first_idx = first_idx[(first_idx >= 0) & (first_idx < first_len)]
  second_idx = second_idx[(second_idx >= 0) & (second_idx < second_len)]
  return first_idx, second_idx

def shift(image, by_x, by_y):
  """Cyclically shift `image` by `by_y` along axis 0 and `by_x` along axis 1.

  Elements pushed past one edge wrap around to the opposite edge
  (`numpy.roll` semantics). A new array is returned; `image` itself is not
  modified.
  """
  rolled = np.roll(image, by_y, 0)
  return np.roll(rolled, by_x, 1)
  

def diffs(image, pred):
  """Compute two differences between shifted copies of `image`.

  NOTE(review): this looks unfinished -- `pred` is never used and both
  results are discarded, so the function always returns None.
  """
  unshifted = shift(image, 0, 0)
  horiz = unshifted - shift(image, -1, 1)
  first = shift(image, -1, -1)
  second = shift(image, 1, -1)
  vert = first - second
  


def shrink(coords, by_x, by_y):
  """Shrink an (x, y, w, h) rectangle inward by the given margins.

  `by_x` and `by_y` are each either a single amount applied to both sides
  or a pair: (left, right) for `by_x`, (top, bottom) for `by_y`. Negative
  amounts grow the rectangle. The result is clamped so that x and y are at
  least 0 and w and h are at least 1.
  """
  x, y, w, h = coords
  d_left, d_right = to_two_tuple(by_x)
  d_top, d_bot = to_two_tuple(by_y)
  new_x = x + d_left
  new_y = y + d_top
  # Moving the near edge inward also eats into the size, hence both terms.
  new_w = w - d_left - d_right
  new_h = h - d_top - d_bot
  return max(0, new_x), max(0, new_y), max(1, new_w), max(1, new_h)

def expand(coords, by_x, by_y):
  """Grow an (x, y, w, h) rectangle outward: `shrink` with negated margins.

  `by_x` / `by_y` accept the same scalar-or-pair forms as `shrink`.
  """
  left_amt, right_amt = to_two_tuple(by_x)
  top_amt, bottom_amt = to_two_tuple(by_y)
  negated_x = (-left_amt, -right_amt)
  negated_y = (-top_amt, -bottom_amt)
  return shrink(coords, negated_x, negated_y)

def expand_rectangle(image, coords):
  """Grow `coords` by one pixel per side, then extend it over opaque pixels.

  After the one-pixel growth (via `shrink` with -1 margins), the rectangle
  is expanded with `expand_rectangle_predicate` for as long as the probed
  pixels are not transparent.
  """
  def opaque(img, x, y):
    return not transparent(img, x, y)

  grown = shrink(coords, -1, -1)
  return expand_rectangle_predicate(image, grown, opaque)


def gridlike(image):
  """Return an integer coordinate grid matching `image`'s first two axes.

  With (H, W) the first two dimensions of `image`, the result has shape
  (W, H, 2) and `result[i, j]` equals `(i, j)`.
  """
  rows, cols = np.shape(image)[:2]
  first, second = np.meshgrid(np.arange(cols), np.arange(rows), indexing='ij')
  return np.stack((first, second), -1)

def crisp(image):
  """Run a blur / sharpen / threshold / close pipeline and return each stage.

  Returns a dict with the keys 'thresh', 'close', 'sharpen', 'blur', 'gray'
  and 'mask'. 'mask' is a boolean array marking pixels whose fourth channel
  is 0 when `image` has four channels, and None otherwise; masked pixels are
  zeroed in 'thresh' before the morphological close.
  """
  has_alpha = image.shape[-1] == 4
  mask = (image[:, :, 3] == 0) if has_alpha else None

  # Grayscale -> median blur -> 3x3 sharpening convolution.
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  blur = cv2.medianBlur(gray, 5)
  sharpen_kernel = np.array([[-1, -1, -1],
                             [-1, 9, -1],
                             [-1, -1, -1]])
  sharpen = cv2.filter2D(blur, -1, sharpen_kernel)

  # Inverted binary threshold, with transparent pixels forced to 0.
  _, thresh = cv2.threshold(sharpen, 160, 255, cv2.THRESH_BINARY_INV)
  if mask is not None:
    thresh[mask] = 0

  # Morphological close with a 3x3 rectangular element, applied twice.
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  close = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
  return {
    'thresh': thresh,
    'close': close,
    'sharpen': sharpen,
    'blur': blur,
    'gray': gray,
    'mask': mask,
  }

# https://stackoverflow.com/a/57193144/9919772

# def detect_rectangles(image):
#   if image.shape[-1] == 4:
#     mask = (image[:, :, 3] == 0)
#   else:
#     mask = None
#   # Load image, grayscale, median blur, sharpen image
#   gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#   blur = cv2.medianBlur(gray, 5)
#   sharpen_kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
#   sharpen = cv2.filter2D(blur, -1, sharpen_kernel)
#   # Threshold and morph close
#   thresh = cv2.threshold(sharpen, 160, 255, cv2.THRESH_BINARY_INV)[1]
#   if mask is not None:
#     thresh[mask] = 0
#   kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3,3))
#   close = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=2)
#   # Find contours and filter using threshold area
#   cnts = cv2.findContours(close, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
#   cnts = cnts[0] if len(cnts) == 2 else cnts[1]
#   min_area = 100
#   max_area = 1500
#   image_number = 0
#   for c in cnts:
#     x,y,w,h = cv2.boundingRect(c)
#     yield (x, y, w, h), c
#     # print((x, y), (w, h))
#     # imgcat.imgcat(image[y:y+h, x:x+w])
#     # ROI = image[y:y+h, x:x+w]
#     # cv2.imwrite('ROI_{}.png'.format(image_number), ROI)
#     # cv2.rectangle(image, (x, y), (x + w, y + h), (36,255,12), 2)
#     # image_number += 1


def region_mask(image):
  """Return a uint8 mask that is 255 where `image` is fully transparent.

  A pixel counts as transparent when its fourth (alpha) channel is 0; every
  other pixel maps to 0. The mask is derived directly from the alpha
  channel, so none of the OpenCV filtering in `crisp` is needed.

  Args:
    image: array of shape (H, W, 4).

  Returns:
    An (H, W) uint8 array containing only 0 and 255.

  Raises:
    ValueError: if `image` is not a three-dimensional array with four
      channels (there is no alpha channel to read).
  """
  if np.ndim(image) != 3 or np.shape(image)[-1] != 4:
    raise ValueError('region_mask expects an H x W x 4 (RGBA) image')
  transparent_px = image[:, :, 3] == 0
  return np.where(transparent_px, 255, 0).astype('uint8')

def detect_regions(image):
  """Yield `(bounding_rect, contour)` for every contour found in `image`.

  Contours are retrieved with `cv2.RETR_LIST` and
  `cv2.CHAIN_APPROX_SIMPLE`; `bounding_rect` is the `cv2.boundingRect` of
  each contour.
  """
  # OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x return
  # (contours, hierarchy); the contour list is second-from-last in both.
  found = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  contours = found[-2]
  for contour in contours:
    yield cv2.boundingRect(contour), contour

def entirely_transparent(image, rect):
  """Return whether every pixel inside `rect` has an alpha value of 0.

  `rect` is (x, y, w, h); the alpha channel is read from index 3 of the
  last axis of `image`.
  """
  x, y, w, h = rect
  alpha = image[y:y + h, x:x + w, 3]
  return (alpha == 0).all()

def replace_transparency(image, color, opacity=255):
  """Return a copy of `image` with fully transparent pixels painted `color`.

  Args:
    image: array whose last axis has an alpha channel at index 3.
    color: anything `hex2rgb` accepts (name, hex string, tuple or list).
    opacity: alpha value appended when `color` has fewer than four channels.

  Returns:
    A new array; `image` is left untouched.
  """
  # hex2rgb passes lists through unchanged, so normalise to a tuple before
  # appending the alpha channel (list + tuple would raise TypeError).
  fill = tuple(hex2rgb(color))
  if len(fill) < 4:
    fill = fill + (opacity,)
  transparent_px = image[:, :, 3] == 0
  result = image.copy()
  result[transparent_px] = fill
  return result

def detect_sprites(image, mask=None):
  """Yield the bounding rectangle of each sprite-like region.

  Args:
    image: the sprite sheet; only used to build `mask` when none is given.
    mask: optional binary image passed to `detect_regions`; defaults to
      `region_mask(image)`.

  Yields:
    (x, y, w, h) rectangles whose contour area equals (w - 1) * (h - 1).
    When the command-line option `--all` is set, every region is yielded.
  """
  if mask is None:
    mask = region_mask(image)
  # `args` is only populated by main(); fall back to the strict default so
  # the function also works when this module is imported as a library.
  keep_all = bool(getattr(args, 'all', False))
  for rect, contour in detect_regions(mask):
    _, _, w, h = rect
    # NOTE(review): presumably this equality holds exactly when the contour
    # is an axis-aligned rectangle filling its bounding box -- confirm.
    if keep_all or cv2.contourArea(contour) == (w - 1) * (h - 1):
      yield rect


def left(box):
  """Return the x coordinate of the left edge of an (x, y, w, h) box."""
  x, _y, _w, _h = box
  return x

def right(box):
  """Return x + w, the right edge of an (x, y, w, h) box."""
  x, _y, w, _h = box
  right_edge = x + w
  return right_edge

def top(box):
  """Return the y coordinate of the top edge of an (x, y, w, h) box."""
  _x, y, _w, _h = box
  return y

def bottom(box):
  """Return y + h, the bottom edge of an (x, y, w, h) box."""
  _x, y, _w, h = box
  bottom_edge = y + h
  return bottom_edge

def width(box):
  """Return the width of an array-like, an object, or an (x, y, w, h) box.

  Checked in order: anything with a `shape` (which must be 2- or
  3-dimensional) yields `shape[1]`; anything with a `width` attribute
  yields that attribute; otherwise `box` is unpacked as (x, y, w, h).
  """
  if hasattr(box, 'shape'):
    dims = box.shape
    assert len(dims) in (2, 3)
    return dims[1]
  if hasattr(box, 'width'):
    return box.width
  _x, _y, w, _h = box
  return w

def height(box):
  """Return the height of an array-like, an object, or an (x, y, w, h) box.

  Checked in order: anything with a `shape` (which must be 2- or
  3-dimensional) yields `shape[0]`; anything with a `height` attribute
  yields that attribute; otherwise `box` is unpacked as (x, y, w, h).
  """
  if hasattr(box, 'shape'):
    dims = box.shape
    assert len(dims) in (2, 3)
    return dims[0]
  if hasattr(box, 'height'):
    return box.height
  _x, _y, _w, h = box
  return h

def v2sub(box1, box2):
  """Return (dx, dy): box1's top-left corner minus box2's top-left corner."""
  dx = left(box1) - left(box2)
  dy = top(box1) - top(box2)
  return (dx, dy)
def v2add(box1, box2):
  """Return the component-wise sum of the two boxes' top-left corners."""
  sum_x = left(box1) + left(box2)
  sum_y = top(box1) + top(box2)
  return (sum_x, sum_y)

def collides(box1, box2):
  """Return True when two (x, y, w, h) boxes overlap or touch.

  The comparisons are strict, so boxes that merely share an edge or a
  corner still count as colliding.
  """
  separated_x = right(box1) < left(box2) or right(box2) < left(box1)
  separated_y = top(box1) > bottom(box2) or top(box2) > bottom(box1)
  return not (separated_x or separated_y)

def carve(img, coords):
  """Return the part of `img` covered by the (x, y, w, h) rectangle."""
  x, y, w, h = coords
  rows = slice(y, y + h)
  cols = slice(x, x + w)
  return img[rows, cols]

def show(img, coords):
  """Preview the `coords` region of `img` in the terminal.

  The region is always carved out, but it is only printed -- a '----'
  separator, the coordinates, then the image via `imgcat` -- when stdout is
  a terminal. Returns None.
  """
  region = carve(img, coords)
  if not sys.stdout.isatty():
    return
  print('----')
  print(coords)
  imgcat.imgcat(region)
  sys.stdout.flush()

def include(box1, box2):
  """Return the smallest (x, y, w, h) box that contains both boxes."""
  x0 = min(left(box1), left(box2))
  y0 = min(top(box1), top(box2))
  x1 = max(right(box1), right(box2))
  y1 = max(bottom(box1), bottom(box2))
  return x0, y0, x1 - x0, y1 - y0

def _first_colliding_pair(spans, n):
  """Return the first two distinct spans that collide after shrinking by `n`, or None."""
  ordered = list(spans)
  for index, span1 in enumerate(ordered):
    for span2 in ordered[index + 1:]:
      if collides(shrink(span1, n, n), shrink(span2, n, n)):
        return span1, span2
  return None


def coalesce(spans, n=3):
  """Merge overlapping (x, y, w, h) spans until no two of them collide.

  Two spans are merged when they still collide after each has been shrunk
  by `n` on every side. The pair is replaced by its common bounding box
  (`include`) and the search starts over.

  Args:
    spans: iterable of hashable (x, y, w, h) tuples.
    n: margin handed to `shrink` before each collision test. It is applied
      to every round of merging, not just the first.

  Returns:
    A set of spans in which no pair collides under that test.
  """
  spans = set(spans)
  # Iterate rather than recurse: one merge per pass, so a sheet with many
  # overlapping spans cannot exhaust the recursion limit.
  while True:
    pair = _first_colliding_pair(spans, n)
    if pair is None:
      return spans
    span1, span2 = pair
    merged = set(spans)
    merged.remove(span1)
    merged.remove(span2)
    merged.add(include(span1, span2))
    spans = merged
          
    

def unzip2(xys):
  """Split an iterable of 2-tuples into a tuple of firsts and a tuple of seconds."""
  pairs = [(first, second) for first, second in xys]
  firsts = tuple(first for first, _ in pairs)
  seconds = tuple(second for _, second in pairs)
  return firsts, seconds


def bounding_boxes(image):
  """Yield each rectangle from `detect_sprites` that is not entirely transparent."""
  yield from (
    rect
    for rect in detect_sprites(image)
    if not entirely_transparent(image, rect)
  )

def _expand_indexed(img):
  """Convert palette or grey+alpha PIL images to RGBA; pass anything else through."""
  # Palette images would otherwise expose raw palette indices, and two-channel
  # images would be rejected further down.
  if getattr(img, 'mode', None) in ('P', 'PA', 'LA'):
    return img.convert('RGBA')
  return img


def spritesheet_open(img):
  """Load a sprite sheet as an (H, W, 4) uint8 RGBA array.

  Args:
    img: a filename, or anything `numpy.asarray` accepts (for example an
      array or an already-opened PIL image).

  Returns:
    A uint8 array with four channels. Greyscale input is replicated to
    RGB, a fully opaque alpha channel (255) is appended when missing, and
    values are clipped to 0-255.
  """
  if isinstance(img, str):
    # Image.open is lazy and keeps the file open; read the pixels inside the
    # `with` block so the handle is released afterwards.
    with PIL.Image.open(img) as opened:
      pixels = np.asarray(_expand_indexed(opened))
  else:
    pixels = np.asarray(_expand_indexed(img))
  assert pixels.ndim >= 2
  if pixels.ndim == 2:
    # greyscale image: add a channel axis
    pixels = pixels[:, :, None]
  H, W, C = pixels.shape
  if C == 1:
    pixels = np.concatenate([pixels, pixels, pixels], axis=-1)
    C = 3
  if C < 4:
    assert C == 3, "Expected RGB image"
    opaque = np.full((H, W, 1), 255, dtype=int)
    pixels = np.concatenate([pixels, opaque], axis=-1)
  return np.clip(pixels, 0, 255).astype('uint8')

# def spritemask(img, color):
#   R, G, B, *A = hex2rgb(color)
#   if not A:
#     A = (255,)
#   pixel = np.array([R, G, B, *A], dtype=int)
#   mask = (img[:, :] == pixel).all(-1)
#   indices = np.argwhere(mask)
#   return indices, mask, pixel
#   coords = list(zip(indices[:, 0][::4], indices[:, 1][::4]))
#   return coords

def area(rect):
  """Return (w - 1) * (h - 1) for an (x, y, w, h) rectangle."""
  _x, _y, w, h = rect
  inner_w = w - 1
  inner_h = h - 1
  return inner_w * inner_h

def cmp_rect(rect1, rect2):
  """Three-way comparison of rectangles: by top edge, then by left edge.

  Returns a negative, zero or positive number in the manner expected by
  `functools.cmp_to_key`.
  """
  top1, top2 = top(rect1), top(rect2)
  if top1 == top2:
    return left(rect1) - left(rect2)
  return top1 - top2

import functools

def sort_rects(rects):
  """Return `rects` as a new list ordered by top edge, then by left edge.

  Ordering is delegated to `cmp_rect`; the input iterable is not modified.
  """
  return sorted(rects, key=functools.cmp_to_key(cmp_rect))

def spritesheet2json(filename):
  """Detect the sprites in a sprite sheet and print their rectangles as JSON.

  For each input file one line is written to stdout:
    {"filename": ..., "sprites": [[x, y, w, h], ...]}
  When stdout is a terminal, every sprite is also previewed inline.
  """
  img = spritesheet_open(filename)
  # Preview on a magenta backdrop so transparent areas are easy to see.
  magenta = replace_transparency(img, 'magenta', 255)
  rects = list(bounding_boxes(img))
  for rect in sort_rects(rects):
    show(magenta, rect)
  # NOTE(review): the JSON lists sprites in detection order, while the
  # preview above is sorted -- confirm whether the output should be sorted too.
  print(json.dumps(dict(filename=filename, sprites=rects)), flush=True)

#==============================================================================
# Cmdline
#==============================================================================
import argparse

def get_parser(parser=None):
    """Build (or extend) the command-line argument parser.

    Args:
        parser: an existing `argparse.ArgumentParser` to add the options to;
            a new one is created when omitted, using the docstring of
            `spritesheet2json` as its description.

    Returns:
        The parser with -v/--verbose, -a/--all and -0/--print0 registered.
    """
    if parser is None:
        # __doc__ is None under `python -OO`; fall back to no description.
        description = (spritesheet2json.__doc__ or '').strip()
        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawTextHelpFormatter,
            description=description)

    parser.add_argument('-v', '--verbose',
        action="store_true",
        help="verbose output")

    parser.add_argument('-a', '--all',
        action="store_true",
        help="report every detected region, not only rectangular ones")

    parser.add_argument('-0', '--print0',
        action="store_true",
        help="Prints \\0 after each result rather than newline")
    return parser

# Parsed command-line options. Stays None until main() fills it in (main()
# keeps a value that a caller assigned beforehand). Read as a module-level
# global by run() and detect_sprites().
args = None

#==============================================================================
# Main
#==============================================================================

def run():
    """Process every input file named on the command line and/or on stdin.

    File names come from `args.args`. When none were given, or when stdin is
    not a terminal, stdin is read as well and split on NUL characters if any
    are present, otherwise on line breaks. Each name is passed to
    `spritesheet2json`.
    """
    if args.verbose:
        print(args, file=sys.stderr)
    piped = has_stdin()
    if not args.args and not piped:
        # Interactive session with nothing to process yet: prompt on stderr
        # so the message never mixes into the JSON written to stdout.
        print('Enter input (press Ctrl-D when done):', file=sys.stderr)
    if not args.args or piped:
        indata = sys.stdin.read()
        names = indata.split('\0') if '\0' in indata else indata.splitlines()
        # Drop empty entries (e.g. after a trailing NUL or a blank line) so
        # they are not treated as file names.
        args.args.extend(name for name in names if name)
    # for each arg on cmdline...
    for arg in args.args:
        spritesheet2json(arg)

def main():
    """Command-line entry point.

    Unless the module-level `args` is already set, the command line is
    parsed with `get_parser()` and unrecognised arguments are stored in
    `args.args`. Then `run()` is called and its result returned. A broken
    pipe (errno 32) is swallowed after closing stdout and stderr; any other
    IOError propagates.
    """
    global args
    try:
        if not args:
            args, leftovers = get_parser().parse_known_args()
            args.args = leftovers
        return run()
    except IOError as e:
        # http://stackoverflow.com/questions/15793886/how-to-avoid-a-broken-pipe-error-when-printing-a-large-amount-of-formatted-data
        if e.errno != 32:
            raise
        # The reader went away: close both streams quietly and exit.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.close()
            except IOError:
                pass

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

